{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate DFR models on all ImageNet variations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import numpy as np\n",
    "import tqdm\n",
    "import pickle\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_embeddings(path, mean=None, std=None):\n",
    "    arr = np.load(path)\n",
    "    x, y = arr[\"embeddings\"], arr[\"labels\"]\n",
    "    if mean is not None:\n",
    "        x = (x - mean) / std\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Change data paths here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "imagenet_c_corruptions = [\"brightness\", \"defocus_blur\", \"fog\", \"gaussian_blur\", \"glass_blur\",\n",
    "                          \"jpeg_compression\", \"pixelate\", \"shot_noise\", \"spatter\", \"zoom_blur\",\n",
    "                          \"contrast\", \"elastic_transform\", \"frost\", \"gaussian_noise\",\n",
    "                          \"impulse_noise\", \"motion_blur\", \"saturate\", \"snow\", \"speckle_noise\"]\n",
    "intensities = [1, 2, 3, 4, 5]\n",
    "\n",
    "eval_path_dict = {\n",
    "    \"imagenet_r\": \"/datasets/imagenet-r/imagenet-r_resnet50_val_embeddings.npz\",\n",
    "    \"imagenet_a\": \"/datasets/imagenet-a/imagenet-a_resnet50_val_embeddings.npz\",\n",
    "    \"imagenet\": \"/datasets/imagenet_symlink/resnet50_val_embeddings.npz\",\n",
    "    \"imagenet_stylized\": \"/datasets/imagenet-stylized/imagenet_resnet50_val_embeddings.npz\",\n",
    "}\n",
    "eval_path_dict.update({\n",
    "    f\"imagenet_c_{corruption}_{intensity}\": \n",
    "        f\"/datasets/imagenet-c/{corruption}/{intensity}/imagenet-c_resnet50_val_embeddings.npz\"\n",
    "    for corruption in imagenet_c_corruptions for intensity in intensities\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DFR-IN-SIN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the DFR-IN-SIN weights and preprocessing statistics, then close the archive\n",
    "# instead of leaving its file handle open for the rest of the session.\n",
    "with np.load(\"dfr_insin_weights_bs10k.npz\") as arr:\n",
    "    w = arr[\"w\"]\n",
    "    b = arr[\"b\"]\n",
    "    preprocess_mean = arr[\"preprocess_mean\"]\n",
    "    preprocess_std = arr[\"preprocess_std\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every evaluation set, standardised with the preprocess_mean / preprocess_std\n",
    "# currently in scope. NOTE: all datasets are held in memory at the same time.\n",
    "eval_datasets = dict(\n",
    "    (name, load_embeddings(path, mean=preprocess_mean, std=preprocess_std))\n",
    "    for name, path in eval_path_dict.items()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 99/99 [04:38<00:00,  2.81s/it]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'imagenet_r': (0.27166666666666667, 0.42436666666666667),\n",
       " 'imagenet_a': (0.0024, 0.039466666666666664),\n",
       " 'imagenet': (0.74524, 0.9178),\n",
       " 'imagenet_stylized': (0.21418, 0.39042),\n",
       " 'imagenet_c_brightness_1': (0.69676, 0.89084),\n",
       " 'imagenet_c_brightness_2': (0.68142, 0.87932),\n",
       " 'imagenet_c_brightness_3': (0.65608, 0.86096),\n",
       " 'imagenet_c_brightness_4': (0.6165, 0.83204),\n",
       " 'imagenet_c_brightness_5': (0.56082, 0.7897),\n",
       " 'imagenet_c_defocus_blur_1': (0.5479, 0.78022),\n",
       " 'imagenet_c_defocus_blur_2': (0.4714, 0.71306),\n",
       " 'imagenet_c_defocus_blur_3': (0.32916, 0.55972),\n",
       " 'imagenet_c_defocus_blur_4': (0.2244, 0.42184),\n",
       " 'imagenet_c_defocus_blur_5': (0.14804, 0.30394),\n",
       " 'imagenet_c_fog_1': (0.5928, 0.8149),\n",
       " 'imagenet_c_fog_2': (0.54144, 0.77306),\n",
       " 'imagenet_c_fog_3': (0.46918, 0.70442),\n",
       " 'imagenet_c_fog_4': (0.42214, 0.65012),\n",
       " 'imagenet_c_fog_5': (0.2852, 0.48892),\n",
       " 'imagenet_c_gaussian_blur_1': (0.63952, 0.8516),\n",
       " 'imagenet_c_gaussian_blur_2': (0.50584, 0.74432),\n",
       " 'imagenet_c_gaussian_blur_3': (0.36596, 0.60416),\n",
       " 'imagenet_c_gaussian_blur_4': (0.25314, 0.46306),\n",
       " 'imagenet_c_gaussian_blur_5': (0.11712, 0.25352),\n",
       " 'imagenet_c_glass_blur_1': (0.52564, 0.75754),\n",
       " 'imagenet_c_glass_blur_2': (0.39982, 0.63056),\n",
       " 'imagenet_c_glass_blur_3': (0.18652, 0.36286),\n",
       " 'imagenet_c_glass_blur_4': (0.13852, 0.28622),\n",
       " 'imagenet_c_glass_blur_5': (0.0969, 0.21526),\n",
       " 'imagenet_c_jpeg_compression_1': (0.63082, 0.84532),\n",
       " 'imagenet_c_jpeg_compression_2': (0.5978, 0.82056),\n",
       " 'imagenet_c_jpeg_compression_3': (0.57044, 0.79726),\n",
       " 'imagenet_c_jpeg_compression_4': (0.48156, 0.71328),\n",
       " 'imagenet_c_jpeg_compression_5': (0.35616, 0.5746),\n",
       " 'imagenet_c_pixelate_1': (0.62394, 0.84136),\n",
       " 'imagenet_c_pixelate_2': (0.59816, 0.82088),\n",
       " 'imagenet_c_pixelate_3': (0.4925, 0.72956),\n",
       " 'imagenet_c_pixelate_4': (0.34308, 0.56556),\n",
       " 'imagenet_c_pixelate_5': (0.24734, 0.43948),\n",
       " 'imagenet_c_shot_noise_1': (0.55686, 0.7783),\n",
       " 'imagenet_c_shot_noise_2': (0.44384, 0.67572),\n",
       " 'imagenet_c_shot_noise_3': (0.31784, 0.53692),\n",
       " 'imagenet_c_shot_noise_4': (0.15678, 0.3157),\n",
       " 'imagenet_c_shot_noise_5': (0.0811, 0.18662),\n",
       " 'imagenet_c_spatter_1': (0.68556, 0.88068),\n",
       " 'imagenet_c_spatter_2': (0.58046, 0.79608),\n",
       " 'imagenet_c_spatter_3': (0.47572, 0.70394),\n",
       " 'imagenet_c_spatter_4': (0.39638, 0.62472),\n",
       " 'imagenet_c_spatter_5': (0.28844, 0.49614),\n",
       " 'imagenet_c_zoom_blur_1': (0.4951, 0.72066),\n",
       " 'imagenet_c_zoom_blur_2': (0.40486, 0.62574),\n",
       " 'imagenet_c_zoom_blur_3': (0.33666, 0.55014),\n",
       " 'imagenet_c_zoom_blur_4': (0.27526, 0.47362),\n",
       " 'imagenet_c_zoom_blur_5': (0.22088, 0.39852),\n",
       " 'imagenet_c_contrast_1': (0.61614, 0.83372),\n",
       " 'imagenet_c_contrast_2': (0.55908, 0.7903),\n",
       " 'imagenet_c_contrast_3': (0.44966, 0.68958),\n",
       " 'imagenet_c_contrast_4': (0.21644, 0.41226),\n",
       " 'imagenet_c_contrast_5': (0.05736, 0.13804),\n",
       " 'imagenet_c_elastic_transform_1': (0.63726, 0.84842),\n",
       " 'imagenet_c_elastic_transform_2': (0.43196, 0.6417),\n",
       " 'imagenet_c_elastic_transform_3': (0.54456, 0.76832),\n",
       " 'imagenet_c_elastic_transform_4': (0.42198, 0.64846),\n",
       " 'imagenet_c_elastic_transform_5': (0.1854, 0.35736),\n",
       " 'imagenet_c_frost_1': (0.58326, 0.80026),\n",
       " 'imagenet_c_frost_2': (0.4388, 0.66388),\n",
       " 'imagenet_c_frost_3': (0.3386, 0.55008),\n",
       " 'imagenet_c_frost_4': (0.3211, 0.52996),\n",
       " 'imagenet_c_frost_5': (0.26282, 0.45538),\n",
       " 'imagenet_c_gaussian_noise_1': (0.56938, 0.78956),\n",
       " 'imagenet_c_gaussian_noise_2': (0.47184, 0.70204),\n",
       " 'imagenet_c_gaussian_noise_3': (0.33388, 0.55692),\n",
       " 'imagenet_c_gaussian_noise_4': (0.19394, 0.3711),\n",
       " 'imagenet_c_gaussian_noise_5': (0.07034, 0.16876),\n",
       " 'imagenet_c_impulse_noise_1': (0.47054, 0.70084),\n",
       " 'imagenet_c_impulse_noise_2': (0.39548, 0.6256),\n",
       " 'imagenet_c_impulse_noise_3': (0.3332, 0.55504),\n",
       " 'imagenet_c_impulse_noise_4': (0.18998, 0.36958),\n",
       " 'imagenet_c_impulse_noise_5': (0.0787, 0.18638),\n",
       " 'imagenet_c_motion_blur_1': (0.60332, 0.82138),\n",
       " 'imagenet_c_motion_blur_2': (0.499, 0.73164),\n",
       " 'imagenet_c_motion_blur_3': (0.34266, 0.56304),\n",
       " 'imagenet_c_motion_blur_4': (0.20158, 0.37946),\n",
       " 'imagenet_c_motion_blur_5': (0.13566, 0.27952),\n",
       " 'imagenet_c_saturate_1': (0.64428, 0.85892),\n",
       " 'imagenet_c_saturate_2': (0.61142, 0.83508),\n",
       " 'imagenet_c_saturate_3': (0.67602, 0.87718),\n",
       " 'imagenet_c_saturate_4': (0.58828, 0.81322),\n",
       " 'imagenet_c_saturate_5': (0.48866, 0.72246),\n",
       " 'imagenet_c_snow_1': (0.52446, 0.74868),\n",
       " 'imagenet_c_snow_2': (0.32946, 0.542),\n",
       " 'imagenet_c_snow_3': (0.34678, 0.55758),\n",
       " 'imagenet_c_snow_4': (0.24668, 0.43258),\n",
       " 'imagenet_c_snow_5': (0.19534, 0.36134),\n",
       " 'imagenet_c_speckle_noise_1': (0.58028, 0.7989),\n",
       " 'imagenet_c_speckle_noise_2': (0.5211, 0.74908),\n",
       " 'imagenet_c_speckle_noise_3': (0.36062, 0.58558),\n",
       " 'imagenet_c_speckle_noise_4': (0.27448, 0.48324),\n",
       " 'imagenet_c_speckle_noise_5': (0.18824, 0.36)}"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top-1 / top-5 accuracy of the linear head (w, b) on each evaluation set.\n",
    "results = {}\n",
    "for key, (x, y) in tqdm.tqdm(eval_datasets.items()):\n",
    "    logits = x @ w.T + b[None, :]\n",
    "    # argsort is ascending, so the highest-scoring classes sit in the last columns.\n",
    "    ranked_classes = np.argsort(logits, -1)\n",
    "    targets = y[:, None]\n",
    "    top1_hits = ranked_classes[:, -1:] == targets\n",
    "    top5_hits = ranked_classes[:, -5:] == targets\n",
    "    acc1 = top1_hits.sum() / len(y)\n",
    "    acc5 = top5_hits.sum() / len(y)\n",
    "    results[key] = (acc1, acc5)\n",
    "results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DFR-SIN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the DFR-SIN weights and preprocessing statistics, then close the archive\n",
    "# instead of leaving its file handle open for the rest of the session.\n",
    "with np.load(\"dfr_sin_weights_bs10k.npz\") as arr:\n",
    "    w = arr[\"w\"]\n",
    "    b = arr[\"b\"]\n",
    "    preprocess_mean = arr[\"preprocess_mean\"]\n",
    "    preprocess_std = arr[\"preprocess_std\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every evaluation set, standardised with the preprocess_mean / preprocess_std\n",
    "# currently in scope. NOTE: all datasets are held in memory at the same time.\n",
    "eval_datasets = dict(\n",
    "    (name, load_embeddings(path, mean=preprocess_mean, std=preprocess_std))\n",
    "    for name, path in eval_path_dict.items()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 99/99 [04:26<00:00,  2.69s/it]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'imagenet_r': (0.24566666666666667, 0.3939666666666667),\n",
       " 'imagenet_a': (0.004, 0.04),\n",
       " 'imagenet': (0.65076, 0.85662),\n",
       " 'imagenet_stylized': (0.21952, 0.39462),\n",
       " 'imagenet_c_brightness_1': (0.62642, 0.83606),\n",
       " 'imagenet_c_brightness_2': (0.60964, 0.82278),\n",
       " 'imagenet_c_brightness_3': (0.58636, 0.80426),\n",
       " 'imagenet_c_brightness_4': (0.5512, 0.77514),\n",
       " 'imagenet_c_brightness_5': (0.50192, 0.73456),\n",
       " 'imagenet_c_defocus_blur_1': (0.49322, 0.72342),\n",
       " 'imagenet_c_defocus_blur_2': (0.42214, 0.6562),\n",
       " 'imagenet_c_defocus_blur_3': (0.28902, 0.50306),\n",
       " 'imagenet_c_defocus_blur_4': (0.19266, 0.3718),\n",
       " 'imagenet_c_defocus_blur_5': (0.125, 0.2662),\n",
       " 'imagenet_c_fog_1': (0.54142, 0.76494),\n",
       " 'imagenet_c_fog_2': (0.4977, 0.72516),\n",
       " 'imagenet_c_fog_3': (0.43576, 0.66294),\n",
       " 'imagenet_c_fog_4': (0.39164, 0.61172),\n",
       " 'imagenet_c_fog_5': (0.26734, 0.46046),\n",
       " 'imagenet_c_gaussian_blur_1': (0.57676, 0.79614),\n",
       " 'imagenet_c_gaussian_blur_2': (0.4556, 0.688),\n",
       " 'imagenet_c_gaussian_blur_3': (0.3344, 0.5578),\n",
       " 'imagenet_c_gaussian_blur_4': (0.22988, 0.42664),\n",
       " 'imagenet_c_gaussian_blur_5': (0.10314, 0.22652),\n",
       " 'imagenet_c_glass_blur_1': (0.47884, 0.70622),\n",
       " 'imagenet_c_glass_blur_2': (0.36464, 0.58582),\n",
       " 'imagenet_c_glass_blur_3': (0.17246, 0.33178),\n",
       " 'imagenet_c_glass_blur_4': (0.12718, 0.26032),\n",
       " 'imagenet_c_glass_blur_5': (0.08386, 0.18922),\n",
       " 'imagenet_c_jpeg_compression_1': (0.57746, 0.79562),\n",
       " 'imagenet_c_jpeg_compression_2': (0.5486, 0.77068),\n",
       " 'imagenet_c_jpeg_compression_3': (0.52528, 0.74844),\n",
       " 'imagenet_c_jpeg_compression_4': (0.4416, 0.6662),\n",
       " 'imagenet_c_jpeg_compression_5': (0.32456, 0.53218),\n",
       " 'imagenet_c_pixelate_1': (0.56694, 0.78798),\n",
       " 'imagenet_c_pixelate_2': (0.54396, 0.767),\n",
       " 'imagenet_c_pixelate_3': (0.4478, 0.67696),\n",
       " 'imagenet_c_pixelate_4': (0.3073, 0.51562),\n",
       " 'imagenet_c_pixelate_5': (0.21564, 0.3917),\n",
       " 'imagenet_c_shot_noise_1': (0.4954, 0.71966),\n",
       " 'imagenet_c_shot_noise_2': (0.39166, 0.6174),\n",
       " 'imagenet_c_shot_noise_3': (0.28234, 0.48744),\n",
       " 'imagenet_c_shot_noise_4': (0.13942, 0.28344),\n",
       " 'imagenet_c_shot_noise_5': (0.07378, 0.16646),\n",
       " 'imagenet_c_spatter_1': (0.61844, 0.82942),\n",
       " 'imagenet_c_spatter_2': (0.53068, 0.75062),\n",
       " 'imagenet_c_spatter_3': (0.4378, 0.66214),\n",
       " 'imagenet_c_spatter_4': (0.36506, 0.58216),\n",
       " 'imagenet_c_spatter_5': (0.27226, 0.46338),\n",
       " 'imagenet_c_zoom_blur_1': (0.43632, 0.6612),\n",
       " 'imagenet_c_zoom_blur_2': (0.35466, 0.5703),\n",
       " 'imagenet_c_zoom_blur_3': (0.29802, 0.50264),\n",
       " 'imagenet_c_zoom_blur_4': (0.24108, 0.4287),\n",
       " 'imagenet_c_zoom_blur_5': (0.19092, 0.35858),\n",
       " 'imagenet_c_contrast_1': (0.55732, 0.77904),\n",
       " 'imagenet_c_contrast_2': (0.5065, 0.73462),\n",
       " 'imagenet_c_contrast_3': (0.41106, 0.64238),\n",
       " 'imagenet_c_contrast_4': (0.20344, 0.38206),\n",
       " 'imagenet_c_contrast_5': (0.05514, 0.13004),\n",
       " 'imagenet_c_elastic_transform_1': (0.57022, 0.79206),\n",
       " 'imagenet_c_elastic_transform_2': (0.39184, 0.59836),\n",
       " 'imagenet_c_elastic_transform_3': (0.50334, 0.72846),\n",
       " 'imagenet_c_elastic_transform_4': (0.39528, 0.61504),\n",
       " 'imagenet_c_elastic_transform_5': (0.17602, 0.34176),\n",
       " 'imagenet_c_frost_1': (0.53008, 0.7499),\n",
       " 'imagenet_c_frost_2': (0.40906, 0.62622),\n",
       " 'imagenet_c_frost_3': (0.31854, 0.52154),\n",
       " 'imagenet_c_frost_4': (0.3048, 0.50438),\n",
       " 'imagenet_c_frost_5': (0.2507, 0.43488),\n",
       " 'imagenet_c_gaussian_noise_1': (0.5033, 0.72728),\n",
       " 'imagenet_c_gaussian_noise_2': (0.41456, 0.64078),\n",
       " 'imagenet_c_gaussian_noise_3': (0.29562, 0.50532),\n",
       " 'imagenet_c_gaussian_noise_4': (0.17012, 0.33226),\n",
       " 'imagenet_c_gaussian_noise_5': (0.06196, 0.15008),\n",
       " 'imagenet_c_impulse_noise_1': (0.41644, 0.64484),\n",
       " 'imagenet_c_impulse_noise_2': (0.35068, 0.57414),\n",
       " 'imagenet_c_impulse_noise_3': (0.2962, 0.50608),\n",
       " 'imagenet_c_impulse_noise_4': (0.16758, 0.3335),\n",
       " 'imagenet_c_impulse_noise_5': (0.06922, 0.16562),\n",
       " 'imagenet_c_motion_blur_1': (0.53586, 0.75996),\n",
       " 'imagenet_c_motion_blur_2': (0.43668, 0.6666),\n",
       " 'imagenet_c_motion_blur_3': (0.29796, 0.5042),\n",
       " 'imagenet_c_motion_blur_4': (0.17106, 0.33188),\n",
       " 'imagenet_c_motion_blur_5': (0.11364, 0.24264),\n",
       " 'imagenet_c_saturate_1': (0.58064, 0.80384),\n",
       " 'imagenet_c_saturate_2': (0.54698, 0.7783),\n",
       " 'imagenet_c_saturate_3': (0.60354, 0.82164),\n",
       " 'imagenet_c_saturate_4': (0.5229, 0.75726),\n",
       " 'imagenet_c_saturate_5': (0.43602, 0.66914),\n",
       " 'imagenet_c_snow_1': (0.47662, 0.6972),\n",
       " 'imagenet_c_snow_2': (0.30088, 0.49918),\n",
       " 'imagenet_c_snow_3': (0.31424, 0.51292),\n",
       " 'imagenet_c_snow_4': (0.22508, 0.39834),\n",
       " 'imagenet_c_snow_5': (0.1804, 0.33582),\n",
       " 'imagenet_c_speckle_noise_1': (0.51788, 0.74208),\n",
       " 'imagenet_c_speckle_noise_2': (0.46516, 0.69364),\n",
       " 'imagenet_c_speckle_noise_3': (0.31974, 0.53572),\n",
       " 'imagenet_c_speckle_noise_4': (0.24372, 0.43808),\n",
       " 'imagenet_c_speckle_noise_5': (0.16562, 0.32372)}"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top-1 / top-5 accuracy of the linear head (w, b) on each evaluation set.\n",
    "results = {}\n",
    "for key, (x, y) in tqdm.tqdm(eval_datasets.items()):\n",
    "    logits = x @ w.T + b[None, :]\n",
    "    # argsort is ascending, so the highest-scoring classes sit in the last columns.\n",
    "    ranked_classes = np.argsort(logits, -1)\n",
    "    targets = y[:, None]\n",
    "    top1_hits = ranked_classes[:, -1:] == targets\n",
    "    top5_hits = ranked_classes[:, -5:] == targets\n",
    "    acc1 = top1_hits.sum() / len(y)\n",
    "    acc5 = top5_hits.sum() / len(y)\n",
    "    results[key] = (acc1, acc5)\n",
    "results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DFR-IN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the DFR-IN weights and preprocessing statistics, then close the archive\n",
    "# instead of leaving its file handle open for the rest of the session.\n",
    "with np.load(\"dfr_in_weights_bs10k.npz\") as arr:\n",
    "    w = arr[\"w\"]\n",
    "    b = arr[\"b\"]\n",
    "    preprocess_mean = arr[\"preprocess_mean\"]\n",
    "    preprocess_std = arr[\"preprocess_std\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every evaluation set, standardised with the preprocess_mean / preprocess_std\n",
    "# currently in scope. NOTE: all datasets are held in memory at the same time.\n",
    "eval_datasets = dict(\n",
    "    (name, load_embeddings(path, mean=preprocess_mean, std=preprocess_std))\n",
    "    for name, path in eval_path_dict.items()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 99/99 [07:11<00:00,  4.36s/it]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'imagenet_r': (0.2287, 0.3642),\n",
       " 'imagenet_a': (0.0004, 0.0324),\n",
       " 'imagenet': (0.75224, 0.92474),\n",
       " 'imagenet_stylized': (0.06264, 0.1439),\n",
       " 'imagenet_c_brightness_1': (0.69752, 0.89402),\n",
       " 'imagenet_c_brightness_2': (0.6761, 0.88078),\n",
       " 'imagenet_c_brightness_3': (0.64328, 0.85982),\n",
       " 'imagenet_c_brightness_4': (0.59238, 0.8237),\n",
       " 'imagenet_c_brightness_5': (0.52192, 0.76648),\n",
       " 'imagenet_c_defocus_blur_1': (0.53122, 0.77282),\n",
       " 'imagenet_c_defocus_blur_2': (0.45026, 0.69792),\n",
       " 'imagenet_c_defocus_blur_3': (0.30214, 0.53054),\n",
       " 'imagenet_c_defocus_blur_4': (0.19806, 0.38572),\n",
       " 'imagenet_c_defocus_blur_5': (0.1242, 0.26832),\n",
       " 'imagenet_c_fog_1': (0.55736, 0.7949),\n",
       " 'imagenet_c_fog_2': (0.49118, 0.73648),\n",
       " 'imagenet_c_fog_3': (0.40022, 0.64656),\n",
       " 'imagenet_c_fog_4': (0.33862, 0.57352),\n",
       " 'imagenet_c_fog_5': (0.19868, 0.38268),\n",
       " 'imagenet_c_gaussian_blur_1': (0.6289, 0.85092),\n",
       " 'imagenet_c_gaussian_blur_2': (0.48582, 0.73352),\n",
       " 'imagenet_c_gaussian_blur_3': (0.3359, 0.57272),\n",
       " 'imagenet_c_gaussian_blur_4': (0.22072, 0.4212),\n",
       " 'imagenet_c_gaussian_blur_5': (0.09896, 0.2202),\n",
       " 'imagenet_c_glass_blur_1': (0.49738, 0.73468),\n",
       " 'imagenet_c_glass_blur_2': (0.35814, 0.5842),\n",
       " 'imagenet_c_glass_blur_3': (0.14532, 0.2963),\n",
       " 'imagenet_c_glass_blur_4': (0.10392, 0.22596),\n",
       " 'imagenet_c_glass_blur_5': (0.07048, 0.16738),\n",
       " 'imagenet_c_jpeg_compression_1': (0.61434, 0.84014),\n",
       " 'imagenet_c_jpeg_compression_2': (0.57828, 0.8099),\n",
       " 'imagenet_c_jpeg_compression_3': (0.54776, 0.78324),\n",
       " 'imagenet_c_jpeg_compression_4': (0.45006, 0.68624),\n",
       " 'imagenet_c_jpeg_compression_5': (0.31386, 0.52974),\n",
       " 'imagenet_c_pixelate_1': (0.6086, 0.8363),\n",
       " 'imagenet_c_pixelate_2': (0.5795, 0.8144),\n",
       " 'imagenet_c_pixelate_3': (0.45914, 0.70356),\n",
       " 'imagenet_c_pixelate_4': (0.30168, 0.5203),\n",
       " 'imagenet_c_pixelate_5': (0.20964, 0.39218),\n",
       " 'imagenet_c_shot_noise_1': (0.54358, 0.77536),\n",
       " 'imagenet_c_shot_noise_2': (0.4161, 0.65384),\n",
       " 'imagenet_c_shot_noise_3': (0.27702, 0.4932),\n",
       " 'imagenet_c_shot_noise_4': (0.12046, 0.25916),\n",
       " 'imagenet_c_shot_noise_5': (0.05758, 0.14338),\n",
       " 'imagenet_c_spatter_1': (0.6807, 0.88222),\n",
       " 'imagenet_c_spatter_2': (0.55276, 0.77744),\n",
       " 'imagenet_c_spatter_3': (0.4297, 0.6604),\n",
       " 'imagenet_c_spatter_4': (0.33964, 0.56574),\n",
       " 'imagenet_c_spatter_5': (0.22154, 0.41216),\n",
       " 'imagenet_c_zoom_blur_1': (0.48682, 0.72382),\n",
       " 'imagenet_c_zoom_blur_2': (0.39142, 0.62394),\n",
       " 'imagenet_c_zoom_blur_3': (0.31678, 0.53822),\n",
       " 'imagenet_c_zoom_blur_4': (0.25604, 0.46),\n",
       " 'imagenet_c_zoom_blur_5': (0.20196, 0.38204),\n",
       " 'imagenet_c_contrast_1': (0.58996, 0.82218),\n",
       " 'imagenet_c_contrast_2': (0.5122, 0.76188),\n",
       " 'imagenet_c_contrast_3': (0.37864, 0.62598),\n",
       " 'imagenet_c_contrast_4': (0.14414, 0.30702),\n",
       " 'imagenet_c_contrast_5': (0.03238, 0.08906),\n",
       " 'imagenet_c_elastic_transform_1': (0.63388, 0.8504),\n",
       " 'imagenet_c_elastic_transform_2': (0.40738, 0.61988),\n",
       " 'imagenet_c_elastic_transform_3': (0.49224, 0.7192),\n",
       " 'imagenet_c_elastic_transform_4': (0.35588, 0.57328),\n",
       " 'imagenet_c_elastic_transform_5': (0.12908, 0.27984),\n",
       " 'imagenet_c_frost_1': (0.55508, 0.7795),\n",
       " 'imagenet_c_frost_2': (0.37706, 0.60022),\n",
       " 'imagenet_c_frost_3': (0.26476, 0.46168),\n",
       " 'imagenet_c_frost_4': (0.24484, 0.43636),\n",
       " 'imagenet_c_frost_5': (0.18714, 0.35234),\n",
       " 'imagenet_c_gaussian_noise_1': (0.5666, 0.79394),\n",
       " 'imagenet_c_gaussian_noise_2': (0.45492, 0.69394),\n",
       " 'imagenet_c_gaussian_noise_3': (0.30058, 0.52232),\n",
       " 'imagenet_c_gaussian_noise_4': (0.15468, 0.31704),\n",
       " 'imagenet_c_gaussian_noise_5': (0.04984, 0.1261),\n",
       " 'imagenet_c_impulse_noise_1': (0.44462, 0.68038),\n",
       " 'imagenet_c_impulse_noise_2': (0.36212, 0.594),\n",
       " 'imagenet_c_impulse_noise_3': (0.29518, 0.51504),\n",
       " 'imagenet_c_impulse_noise_4': (0.15148, 0.31632),\n",
       " 'imagenet_c_impulse_noise_5': (0.0566, 0.14472),\n",
       " 'imagenet_c_motion_blur_1': (0.59656, 0.8222),\n",
       " 'imagenet_c_motion_blur_2': (0.48048, 0.71912),\n",
       " 'imagenet_c_motion_blur_3': (0.31222, 0.52966),\n",
       " 'imagenet_c_motion_blur_4': (0.17154, 0.33836),\n",
       " 'imagenet_c_motion_blur_5': (0.11288, 0.24448),\n",
       " 'imagenet_c_saturate_1': (0.6201, 0.84662),\n",
       " 'imagenet_c_saturate_2': (0.58384, 0.81874),\n",
       " 'imagenet_c_saturate_3': (0.66752, 0.8779),\n",
       " 'imagenet_c_saturate_4': (0.54994, 0.78792),\n",
       " 'imagenet_c_saturate_5': (0.4225, 0.6574),\n",
       " 'imagenet_c_snow_1': (0.48798, 0.7192),\n",
       " 'imagenet_c_snow_2': (0.27274, 0.47704),\n",
       " 'imagenet_c_snow_3': (0.29772, 0.50612),\n",
       " 'imagenet_c_snow_4': (0.19134, 0.36612),\n",
       " 'imagenet_c_snow_5': (0.14072, 0.28672),\n",
       " 'imagenet_c_speckle_noise_1': (0.56614, 0.79434),\n",
       " 'imagenet_c_speckle_noise_2': (0.50212, 0.73646),\n",
       " 'imagenet_c_speckle_noise_3': (0.3211, 0.54238),\n",
       " 'imagenet_c_speckle_noise_4': (0.23018, 0.42732),\n",
       " 'imagenet_c_speckle_noise_5': (0.14488, 0.30094)}"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top-1 / top-5 accuracy of the linear head (w, b) on each evaluation set.\n",
    "results = {}\n",
    "for key, (x, y) in tqdm.tqdm(eval_datasets.items()):\n",
    "    logits = x @ w.T + b[None, :]\n",
    "    # argsort is ascending, so the highest-scoring classes sit in the last columns.\n",
    "    ranked_classes = np.argsort(logits, -1)\n",
    "    targets = y[:, None]\n",
    "    top1_hits = ranked_classes[:, -1:] == targets\n",
    "    top5_hits = ranked_classes[:, -5:] == targets\n",
    "    acc1 = top1_hits.sum() / len(y)\n",
    "    acc5 = top5_hits.sum() / len(y)\n",
    "    results[key] = (acc1, acc5)\n",
    "results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38",
   "language": "python",
   "name": "py38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
